from __future__ import division

import numpy
import matplotlib.pyplot as pyplot
from matplotlib.collections import LineCollection
from matplotlib.colors import LogNorm
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.patches as patches
import os
import scipy.interpolate as interp
import scipy.optimize as optimise
import matplotlib.gridspec as gridspec
from jqc import jqc_plot
from mpl_toolkits.axes_grid1.inset_locator import inset_axes,zoomed_inset_axes,mark_inset
from matplotlib.transforms import TransformedBbox
from mpl_toolkits.axes_grid1.inset_locator import BboxPatch, BboxConnector,BboxConnectorPatch
from scipy.misc import derivative

# Apply the shared JQC plot style before any figure is created below.
jqc_plot.plot_style("normal")

# Directory containing this script; every data and output path is built from it.
cwd = os.path.split(os.path.abspath(__file__))[0]

# State selection used by ExtractData: only basis components with these Rb and
# Cs nuclear-spin projections are kept, and MN is the reference M_N
# (MN -> pi, MN - 1 -> sigma-, MN + 1 -> sigma+ coefficients).
mI_Rb, mI_Cs, MN = 3/2, 7/2, 0

################################START OF My Functions###########################################
#################Define best way to plot the colormapped lines##############

# Segment data for LinearSegmentedColormap: each channel is a list of
# (x, value_below_x, value_above_x) anchors on [0, 1].
# "twk_blue" runs from cream (244, 234, 168) to dark blue (0, 32, 58).
colour_dict_twk_blue = {
    "red" : [(0.0,244.0/255.0,244.0/255.0),
            (0.33,124.0/255.0,124.0/255.0),
            (0.66,0.0,0.0),
            (1.0,0.0,0.0)] ,
    "green" : [(0.0,234.0/255.0,234.0/255.0),
            (0.33,154.0/255.0,154.0/255.0),
            (0.66,70.0/255.0,70.0/255.0),
            (1.0,32.0/255.0,32.0/255.0)]  ,
    "blue" : [(0.0,168.0/255.0,168.0/255.0),
            (0.33,148.0/255.0,148.0/255.0),
            (0.66,127.0/255.0,127.0/255.0),
            (1.0,58.0/255.0,58.0/255.0)]
}

# "twk_red" runs from the same cream (244, 234, 168) to dark red (170, 43, 74).
colour_dict_twk_red = {
    "red" : [(0.0,244.0/255.0,244.0/255.0),
            (0.33,229.0/255.0,229.0/255.0),
            (0.66,214.0/255.0,214.0/255.0),
            (1.0,170.0/255.0,170.0/255.0)] ,
    "green" : [(0.0,234.0/255.0,234.0/255.0),
            (0.33,177.0/255.0,177.0/255.0),
            (0.66,120.0/255.0,120.0/255.0),
            (1.0,43.0/255.0,43.0/255.0)]  ,
    "blue" : [(0.0,168.0/255.0,168.0/255.0),
            (0.33,145.0/255.0,145.0/255.0),
            (0.66,122.0/255.0,122.0/255.0),
            (1.0,74.0/255.0,74.0/255.0)]
}

# Alpha ramp shared by both maps: transparent at 0 and fully opaque from 0.5
# upwards, so the lowest normalised values fade out instead of being drawn solid.
_alpha_ramp = ((0.0, 0.0, 0.0),
               (0.5, 1.0, 1.0),
               (1.0, 1.0, 1.0))


def _register_colormap(cmap):
    '''
    Register ``cmap`` under its own name so it can be referenced by string
    (e.g. ``cmap='RbCs_map_tweak_blue'``).

    ``pyplot.register_cmap`` was deprecated in Matplotlib 3.7 and removed in
    3.9, so the ``matplotlib.colormaps`` registry is preferred when it provides
    ``register``; older installations fall back to the legacy call.
    '''
    import matplotlib
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "register"):
        # force=True lets the script be re-run in the same interpreter
        # without a "colormap already registered" error.
        registry.register(cmap, force=True)
    else:
        pyplot.register_cmap(cmap=cmap)


colour_dict_twk_blue_alpha = colour_dict_twk_blue.copy()
colour_dict_twk_blue_alpha['alpha'] = _alpha_ramp

RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",colour_dict_twk_blue_alpha)
_register_colormap(RbCs_map_twk_blue)

colour_dict_twk_red_alpha = colour_dict_twk_red.copy()
colour_dict_twk_red_alpha['alpha'] = _alpha_ramp

RbCs_map_twk_red = LinearSegmentedColormap("RbCs_map_tweak_red",colour_dict_twk_red_alpha)
_register_colormap(RbCs_map_twk_red)


def make_segments(x, y):
    '''
    Pair consecutive (x, y) vertices into two-point line segments.

    For 1-D inputs of length n the result has shape (n - 1, 2, 2): one entry
    per segment, holding its start and end points as (x, y) rows. This is the
    layout accepted by matplotlib's LineCollection.
    '''
    vertices = numpy.transpose(numpy.array((x, y))).reshape(-1, 1, 2)
    starts, ends = vertices[:-1], vertices[1:]
    return numpy.hstack((starts, ends))

def colorline(x, y, z=None, cmap=None,
                norm=None, linewidth=3, alpha=1.0,
                legend=False,ax=None):
    '''
    Plot a line through (x, y) whose segments are coloured by the values z.

    Parameters
    ----------
    x, y : array-like
        Vertex coordinates; ``len(x) - 1`` segments are drawn.
    z : array-like or scalar, optional
        Values mapped through ``norm`` and ``cmap`` to colour the segments.
        Defaults to ``len(x)`` values evenly spaced on [0, 1]. A scalar
        (including a 0-d array) is wrapped into a one-element array.
    cmap : Colormap or str, optional
        Colormap instance or registered name. Defaults to ``'copper'``.
    norm : Normalize, optional
        Normalisation applied to z. Defaults to a fresh ``Normalize(0.0, 1.0)``
        per call. A default built once at import time would be shared, and
        mutated, by every collection relying on it.
    linewidth : float
        Line width of the segments.
    alpha, legend :
        Accepted for backward compatibility but not used. NOTE: a scalar alpha
        on a collection would replace the colormap's own alpha channel, which
        the RbCs maps rely on.
    ax : Axes, optional
        Axes to draw on; defaults to the current axes.

    Returns
    -------
    LineCollection
        The collection added to ``ax``, e.g. for building a colourbar.
    '''
    if ax is None:
        ax = pyplot.gca()
    if cmap is None:
        cmap = pyplot.get_cmap('copper')
    if norm is None:
        norm = pyplot.Normalize(0.0, 1.0)

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # Wrap a scalar (or 0-d array) so LineCollection receives a 1-D array.
    if numpy.ndim(z) == 0:
        z = numpy.array([z])
    z = numpy.asarray(z)

    segments = make_segments(x, y)
    # zorder 1.25 keeps the coloured lines above the zorder=0 reference curves
    # drawn by the script and below the inset-highlight rectangle (zorder 1.45).
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth, zorder=1.25)

    ax.add_collection(lc)

    return lc

def RescaleField(line,x,factor):
    '''
    Evaluate a tabulated curve at field values scaled by ``factor``.

    ``line`` is a two-column table: column 0 is the field axis and column 1
    the curve value. The curve is linearly interpolated and sampled at
    ``factor * x``. Points outside the tabulated range raise ValueError, which
    is the interp1d default.
    '''
    field_axis, values = line[:, 0], line[:, 1]
    curve = interp.interp1d(field_axis, values)
    scaled_field = factor * x
    return curve(scaled_field)


def ExtractData(FileStart,groundstate=None,FileMax=81):
    '''
    Parse a numbered series of fixed-width "_coef.dat" files into one array.

    File ``n`` (1-based) is ``FileStart + str(n).zfill(4) + "b0001_coef.dat"``.
    Files 1..FileMax are read in order, and the number of eigenstates is taken
    from the "Num" header line of the first file.

    Only basis components ("J" lines) whose nuclear-spin projections match the
    module-level ``mI_Rb`` and ``mI_Cs`` contribute. Their coefficients are
    summed into the pi / sigma- / sigma+ columns according to whether their
    M_N equals ``MN``, ``MN - 1`` or ``MN + 1``.

    Parameters
    ----------
    FileStart : str
        Path prefix shared by every file in the series.
    groundstate : optional
        Unused; kept so existing calls keep working.
    FileMax : int, optional
        Number of files to read. Defaults to 81, the previously hard-coded value.

    Returns
    -------
    numpy.ndarray of shape (FileMax, StateNum + 2, 6)
        ``[:, 0, :]``: electric field (file value / 1e5, labelled V/cm).
        ``[:, 1, :]``: magnetic field (file value * 1e4, labelled Gauss).
        ``[:, 2:, :]``: one row per eigenstate, holding
        [energy / 1e3 (labelled MHz), N, M_N, pi, sigma-, sigma+ coefficient].

    Raises
    ------
    ValueError
        If the first file has no line starting with "Num".
    '''
    FileEnd = "b0001_coef.dat"

    def _filename(index):
        # e.g. FileStart + "0001" + FileEnd
        return FileStart + str(index).zfill(4) + FileEnd

    def _read_lines(path):
        # One string per line. The 'none' delimiter presumably never occurs,
        # so lines are not split. atleast_1d keeps a one-line file iterable
        # instead of yielding a 0-d array.
        return numpy.atleast_1d(
            numpy.genfromtxt(path, dtype='str', delimiter='none'))

    # Determine the number of eigenstates from the first file's header.
    first_file = _filename(1)
    StateNum = None
    for row in _read_lines(first_file):
        if row[0:3] == "Num":
            StateNum = int(row[4:])
            break
    if StateNum is None:
        raise ValueError("no 'Num' header line found in " + first_file)

    #Array to store data
    #Energy, N, MN, Coeffpi, Coeffs-, Coeffs+
    Lines = numpy.zeros((FileMax, StateNum+2, 6))
    for iE in range(FileMax):
        Data = _read_lines(_filename(iE+1))

        iL = 1  # the first "Eig" line advances this to row 2 (first state)
        for i, row in enumerate(Data):
            if row[0] == "B":
                # fixed-width field values; 'D' exponent markers become 'E'
                B = row[2:20]
                E = row[30:48]
                Lines[iE, 1, :] = float(B.replace('D', 'E'))*1e4 #Gauss
                Lines[iE, 0, :] = float(E.replace('D', 'E'))/1e5 #V/cm

            if row[0:3] == "Eig":
                # new eigenstate: reset the accumulated coefficients
                CoeffPi = 0
                CoeffSm = 0
                CoeffSp = 0
                iL += 1
                Lines[iE, iL, 0] = float(row[11:])/1e3 #MHz

            if row[0] == "J":
                mI_Rb_i = float(row[32:43])
                mI_Cs_i = float(row[47:59])

                if abs(mI_Rb_i - mI_Rb) <= 0.1 and abs(mI_Cs_i - mI_Cs) <= 0.1:
                    MN_i = float(row[17:28])
                    Lines[iE, iL, 1] = float(row[2:12]) #N
                    Lines[iE, iL, 2] = MN_i #MN
                    # the coefficient for this component ends the preceding line
                    coeff_text = Data[i-1][-15:-1]
                    if abs(MN_i - MN) <= 0.1:
                        CoeffPi += float(coeff_text) #coeff for pi transitions
                    elif abs(MN_i - (MN-1)) <= 0.1:
                        CoeffSm += float(coeff_text) #coeff for sigma- transitions
                    elif abs(MN_i - (MN+1)) <= 0.1:
                        CoeffSp += float(coeff_text) #coeff for sigma+ transitions
                    Lines[iE, iL, 3] = CoeffPi
                    Lines[iE, iL, 4] = CoeffSm
                    Lines[iE, iL, 5] = CoeffSp
    return Lines

#############################SCRIPT RUN################################
# All paths are built with os.path.join so the script is not tied to
# Windows "\\" separators; on Windows the resulting strings are unchanged.

# Experimental data: column 0 is doubled to give the nominal field, column 1
# is the transition frequency and column 2 its uncertainty (yerr / sigma).
Data_exp = numpy.genfromtxt(os.path.join(cwd, "Experiment", "Data_2_3_19.csv"), delimiter=',')
E_exp = 2*Data_exp[:,0]

fpath = os.path.join(cwd, "FineJesus")

data = numpy.genfromtxt(os.path.join(fpath, "MFp05_1", "results.dat"))
GS = data[:,2]

Efit = data[:,0]*1e-2 #V/m to V/cm

# The fitted transition follows column 3 below 40 V/cm and column 5 above it.
locs = numpy.where(Efit<40)
fitting_LineA = data[:,3]
fitting_lineB = data[:,5]

fitting_line = fitting_lineB.copy()
fitting_line[locs] = fitting_LineA[locs]

fitting_line = fitting_line-GS

# Columns: main plot, spacer, pi colourbar, sigma colourbar.
Grid = gridspec.GridSpec(1,4,width_ratios=[1,0.02,0.05,0.05])

figure= pyplot.figure("FINE DC")
ax = figure.add_subplot(Grid[0])
ax2 = inset_axes(ax,width="40%",height="40%",loc=2,
                   bbox_to_anchor=(0.2, 0, 1, 1),
                   bbox_transform=ax.transAxes)

cax1 = figure.add_subplot(Grid[2])
cax2 = figure.add_subplot(Grid[3])
Trans_fitting = numpy.vstack([Efit,fitting_line*1e-3]).T


def fitting_fn(x, a):
    # Theory curve evaluated at the field scaled by the calibration factor a.
    return RescaleField(Trans_fitting, x, a)


fit_line = interp.interp1d(Trans_fitting[:,0],Trans_fitting[:,1])

# NOTE(review): scipy.misc.derivative was deprecated in SciPy 1.10 and removed
# in 1.12. A central difference (f(x+1) - f(x-1)) / 2 reproduces these default
# calls if the import has to be dropped.
F = fit_line(100)
dF = derivative(fit_line,100)
ddF = derivative(lambda x: derivative(fit_line,x),100)
print(F,dF,ddF)

# Fit a single field-calibration factor, excluding the last three points,
# then rescale the experimental fields with it.
curve,cov = optimise.curve_fit(fitting_fn,E_exp[:-3],Data_exp[:-3,1],p0=[.1508],sigma=Data_exp[:-3,2],absolute_sigma=True,bounds=([.100],[.175]))
E_exp = E_exp*curve[0]
print(curve,numpy.sqrt(cov[0,0]))

dataset_high = os.path.join(cwd, "nmax5_Ep", "nmax5_Ep")
dataset_low = os.path.join(cwd, "nmax5_small_E")

line_colour = (244.0/255.0, 234.0/255.0, 168.0/255.0)
vmin = 1e-3  # lower limit of the LogNorm colour scale

for MF in range(-6,7,1):
    # Directory tag, e.g. MFm06_1 for MF = -6 and MFp00_1 for MF = 0
    sign = "m" if MF < 0 else "p"
    string = "MF" + sign + str(abs(MF)).zfill(2) + "_1"
    data = numpy.genfromtxt(os.path.join(fpath, string, "results.dat"))

    # Transition energies relative to GS (column 2 of the MFp05_1 results)
    lines = data[:,2:] - GS[:, None]

    #load the transition coefficients for this MF using ExtractData
    Dat_H = ExtractData(os.path.join(dataset_high, string, "uncoupled_e"))
    Dat_L = ExtractData(os.path.join(dataset_low, string, "uncoupled_e"))

    Dat = numpy.concatenate([Dat_L,Dat_H],axis=0)
    # Reverse the row axis: row -1 is now the electric-field row (row 0 of
    # ExtractData's output) and rows 0..-3 are the eigenstates.
    Dat = Dat[:,::-1,:]

    E_plot = data[:,0]*1e-2    # V/m to V/cm, as for Efit
    E_interp = data[:,0]*1e-5  # field axis at which the coefficient interpolators are evaluated

    ax.plot(E_plot,lines*1e-3,color=line_colour, zorder=0)
    ax2.plot(E_plot,lines*1e-3,color=line_colour, zorder=0)

    for k in range(Dat.shape[1]-2):
        f_k = lines[:,k]*1e-3
        Pi = interp.interp1d(Dat[:, -1, 0],Dat[:,k,3])
        Sm = interp.interp1d(Dat[:, -1, 0],Dat[:,k,4])
        Sp = interp.interp1d(Dat[:, -1, 0],Dat[:,k,5])
        z_pi, z_sm, z_sp = Pi(E_interp), Sm(E_interp), Sp(E_interp)

        # Draw pi (blue), sigma- and sigma+ (red) on the main plot and inset,
        # in the same per-axes order as before.
        for target in (ax, ax2):
            lc_pi = colorline(E_plot,f_k,z_pi,cmap='RbCs_map_tweak_blue',
                              norm=LogNorm(vmin,vmax=1),linewidth=2.0,ax=target)
            colorline(E_plot,f_k,z_sm,cmap='RbCs_map_tweak_red',
                      norm=LogNorm(vmin,vmax=1),linewidth=2.0,ax=target)
            lc_sp = colorline(E_plot,f_k,z_sp,cmap='RbCs_map_tweak_red',
                              norm=LogNorm(vmin,vmax=1),linewidth=2.0,ax=target)
            if target is ax:
                cl1, cl2 = lc_pi, lc_sp


ax.errorbar(E_exp,Data_exp[:,1],yerr=Data_exp[:,2],fmt='o',color='k')
ax2.errorbar(E_exp,Data_exp[:,1],yerr=Data_exp[:,2],fmt='o',color='k')


# Colourbars use the last pi (blue) and sigma+ (red) collections drawn on ax;
# every collection shares the same LogNorm(1e-3, 1) range.
cbar1 = pyplot.colorbar(cl1,cax=cax1, pad=-.08)
#cbar1.set_ticks([1e-3, 1e-2, 1e-1, 1e-0])
cax1.axes.get_yaxis().set_ticklabels([])
cbar1.ax.set_title("$z$",color=jqc_plot.colours['blue'])

cbar2 = pyplot.colorbar(cl2,cax=cax2, pad=-.08)
cbar2.set_label('Relative Transition Strength')
#cbar2.set_ticks([1e-3, 1e-2, 1e-1, 1e-0])
cbar2.ax.set_title("$y$",color=jqc_plot.colours['red'])

ax.set_xlabel("Electric Field, $E_z$ (V$\\,$cm$^{-1}$)")
ax.set_ylabel("Transition Frequency, $f$ (MHz)")

ax.set_xlim(-2,500)
ax.set_ylim(979,1030)

ax2.set_ylabel("$f$ (MHz)")
ax2.set_xlabel("$E_z$ (V$\\,$cm$^{-1}$)")

ax2.set_xlim(0,100)
ax2.set_ylim(980,982)

#mark_inset(ax,ax2,loc1=2,loc2=1,fc='none',ec=jqc_plot.colours['grayblue'])

# Outline the inset's region on the main axes.
x1,x2 = ax2.get_xlim()
y1,y2 = ax2.get_ylim()
w,h = x2-x1,y2-y1

Rect = patches.Rectangle((x1,y1),width=w,height=h,transform=ax.transData,
                        ec=jqc_plot.colours['grayblue'],fc='none',
                        zorder=1.45,lw=1.5)
ax.add_patch(Rect)

Rect_bbox = TransformedBbox(Rect.get_bbox(),ax.transData)
#draw a polygon to connect the highlight to the zoomed axis (currently not added).
p1 = BboxConnectorPatch(Rect_bbox,ax2.bbox, loc1a=2, loc2a=3,loc1b=1,
                        loc2b=4, ec=jqc_plot.colours['grayblue'],zorder = 1.5,lw=1.5)

#ax.add_patch(p1)

pyplot.tight_layout()
pyplot.subplots_adjust(wspace=0)

pyplot.savefig(os.path.join(cwd, "output", "DC_Stark.pdf"))
#pyplot.savefig(os.path.join(cwd, "output", "DC_Stark.png"))

pyplot.show()
